import torch
import torch.nn as nn

rnn = nn.RNN(5, 6, 1)    # 第一个参数：输入张量的维度大小； 第二个参数：隐藏层的维度大小； 第三个参数：隐藏层的数量
input001 = torch.randn(1, 3, 5)  #① 序列长度为1； ② batchsize为3； ③ 输入张量的维度为5
h0 = torch.randn(1, 3, 6)  #① 与rnn的第三个参数对应； ② batchsize； ③ 隐藏层的维度大小

output001, h1 = rnn(input001,h0)
print(output001)
print("----------------------")
print(h1)






